import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
from gps.cnf.nflow import NormalizingFlow



def gen_network(n_inputs, n_outputs, hidden=(10,), activation='tanh', dropout=0.2):
    """Build a fully connected network as an ``nn.Sequential``.

    Every hidden layer contributes three modules, in this order:
    ``Linear -> Dropout -> activation``. A final ``Linear`` maps the last
    hidden width (or ``n_inputs`` when ``hidden`` is empty) to ``n_outputs``.

    Args:
        n_inputs: Number of input features.
        n_outputs: Number of output features.
        hidden: Sequence with the width of each hidden layer. May be empty,
            in which case the network is a single linear map.
        activation: One of ``'tanh'``, ``'relu'`` or ``'leakyrelu'``. Any
            other value falls back to ReLU.
        dropout: Drop probability of the ``nn.Dropout`` placed after every
            hidden linear layer.

    Returns:
        The assembled ``nn.Sequential`` model.
    """
    activations = {
        'tanh': nn.Tanh,
        'relu': nn.ReLU,
        'leakyrelu': nn.LeakyReLU,
    }
    # Unknown names silently fall back to ReLU.
    act_cls = activations.get(activation, nn.ReLU)

    layers = []
    in_features = n_inputs
    for width in hidden:
        layers.extend((
            nn.Linear(in_features, width),
            nn.Dropout(dropout),
            act_cls(),
        ))
        in_features = width

    # Output layer; with no hidden layers it connects the input directly.
    layers.append(nn.Linear(in_features, n_outputs))

    return nn.Sequential(*layers)



class CNFLayer(nn.Module):
    """Conditional affine coupling layer.

    Coordinates where ``mask`` is 1 pass through unchanged and, together with
    the optional condition ``C``, feed two networks that produce a shift ``T``
    and a log-scale ``S``. Coordinates where ``mask`` is 0 are transformed as
    ``X * exp(S) + T`` by :meth:`f` and mapped back by :meth:`g`.

    Note: the networks built by ``gen_network`` contain dropout, so ``g`` is
    the exact inverse of ``f`` only while the module is in eval mode.
    """

    def __init__(self, DEVICE, var_size, cond_size, mask, hidden=(10,), activation='tanh'):
        """Create the layer.

        Args:
            DEVICE: Device the mask is moved to.
            var_size: Number of variables in ``X``.
            cond_size: Number of condition features in ``C`` (0 if unused).
            mask: 1-D tensor of length ``var_size`` with 0/1 entries.
            hidden: Hidden layer widths for both internal networks.
            activation: Activation name forwarded to ``gen_network``.
        """
        super().__init__()

        # Registered as a non-persistent buffer so the mask follows the module
        # across ``.to(...)`` calls without adding a key to ``state_dict``.
        self.register_buffer('mask', mask.to(DEVICE), persistent=False)
        self.nn_t = gen_network(var_size + cond_size, var_size, hidden, activation)
        self.nn_s = gen_network(var_size + cond_size, var_size, hidden, activation)

    def _shift_and_scale(self, X, C=None):
        """Return ``(T, S, keep)``: shift, log-scale and the broadcastable mask."""
        keep = self.mask[None, :]
        XC = X * keep
        if C is not None:
            XC = torch.cat((XC, C), dim=1)

        T = self.nn_t(XC)
        S = self.nn_s(XC)
        return T, S, keep

    def f(self, X, C=None):
        """Apply the forward transform.

        Args:
            X: Input batch; the mask is broadcast over its first dimension.
            C: Optional condition batch concatenated to the masked input
                along ``dim=1``.

        Returns:
            Tuple ``(X_new, log_det)`` where ``log_det`` is ``S`` summed over
            the transformed (mask == 0) coordinates along the last dimension.
        """
        T, S, keep = self._shift_and_scale(X, C)

        X_new = (X * torch.exp(S) + T) * (1 - keep) + X * keep
        log_det = (S * (1 - keep)).sum(dim=-1)
        return X_new, log_det

    def g(self, X, C=None):
        """Apply the inverse transform of :meth:`f`.

        Args:
            X: Batch in the output space of :meth:`f`.
            C: Optional condition batch, same as passed to :meth:`f`.

        Returns:
            The batch mapped back through ``(X - T) * exp(-S)`` on the
            transformed coordinates; masked coordinates are unchanged.
        """
        T, S, keep = self._shift_and_scale(X, C)

        X_new = ((X - T) * torch.exp(-S)) * (1 - keep) + X * keep
        return X_new



class CNF:
    """Trainer/wrapper around a conditional normalizing flow.

    The flow is a stack of ``CNFLayer`` coupling layers with alternating
    masks, wrapped in ``NormalizingFlow`` with a standard multivariate normal
    prior. The model is created lazily on the first :meth:`fit` call.
    """

    def __init__(self, DEVICE, n_layers=8, hidden=(10,), activation='tanh',
                 batch_size=32, n_epochs=10, lr=0.0001, weight_decay=0):
        """Store hyperparameters; no model is built yet.

        Args:
            DEVICE: Device used for the model and for every batch.
            n_layers: Number of coupling layers.
            hidden: Hidden layer widths of each coupling layer's networks.
            activation: Activation name used inside the coupling layers.
            batch_size: Mini-batch size used by :meth:`fit`.
            n_epochs: Number of passes over the data per :meth:`fit` call.
            lr: Adam learning rate.
            weight_decay: Adam weight decay.
        """
        self.n_layers = n_layers
        self.hidden = hidden
        self.activation = activation
        self.batch_size = batch_size
        self.n_epochs = n_epochs
        self.lr = lr
        self.weight_decay = weight_decay
        self.DEVICE = DEVICE

        # Built lazily by _model_init.
        self.prior = None
        self.nf = None
        self.opt = None

        # One entry per optimization step (0-d CPU tensors).
        self.loss_history = []
        # Not written anywhere in this class.
        self.val_loss = []

    def _model_init(self, X, C):
        """Create prior, flow and optimizer if they do not exist yet.

        Sizes are read from ``X.shape[1]`` and ``C.shape[1]`` (0 when ``C``
        is ``None``). Objects that already exist are left untouched, so
        repeated ``fit`` calls continue training the same model.
        """
        var_size = X.shape[1]
        cond_size = C.shape[1] if C is not None else 0

        # init prior
        if self.prior is None:
            self.prior = torch.distributions.MultivariateNormal(
                torch.zeros(var_size, device=self.DEVICE),
                torch.eye(var_size, device=self.DEVICE))

        # init NF model and optimizer
        if self.nf is None:
            # The mask is shifted by the layer index so that consecutive
            # layers transform complementary sets of coordinates.
            layers = [
                CNFLayer(DEVICE=self.DEVICE, var_size=var_size,
                         cond_size=cond_size,
                         mask=((torch.arange(var_size) + i) % 2),
                         hidden=self.hidden,
                         activation=self.activation)
                for i in range(self.n_layers)
            ]

            self.nf = NormalizingFlow(layers=layers, prior=self.prior).to(self.DEVICE)
            self.opt = torch.optim.Adam(self.nf.parameters(),
                                        lr=self.lr,
                                        weight_decay=self.weight_decay)

    def fit(self, X, C=None):
        """Train the flow by minimizing the mean negative log-probability.

        Args:
            X: 2-D tensor of samples (rows) by variables (columns).
            C: Optional 2-D tensor of conditions with one row per sample.

        Both inputs must already be tensors; they are wrapped in a
        ``TensorDataset`` as-is. The loss of every step is appended to
        ``self.loss_history``.
        """
        # model init
        self._model_init(X, C)

        # tensors to dataset
        dataset = TensorDataset(X, C) if C is not None else TensorDataset(X)
        loader = DataLoader(dataset, batch_size=self.batch_size, shuffle=True)

        self.nf.train()
        for _ in range(self.n_epochs):
            for batch in loader:
                X_batch = batch[0].to(self.DEVICE)

                # Gaussian noise with std 0.1 is added to every batch.
                X_batch = X_batch + torch.randn_like(X_batch) * 0.1

                C_batch = batch[1].to(self.DEVICE) if C is not None else None

                # calculate loss
                loss = -self.nf.log_prob(X_batch, C_batch).mean()

                # optimization step
                self.opt.zero_grad()
                loss.backward()
                self.opt.step()

                # store loss
                self.loss_history.append(loss.detach().cpu())

    def pob(self, X, C=None):
        """Return ``exp(nf.log_prob(X, C))`` as a detached CPU tensor.

        Args:
            X: Tensor of samples to evaluate.
            C: Optional tensor of conditions; ``None`` is forwarded as-is.

        The flow is switched to eval mode before evaluation.
        """
        X = X.to(self.DEVICE)
        if C is not None:
            C = C.to(self.DEVICE)

        self.nf.eval()
        log_pob = self.nf.log_prob(X, C)
        pob = torch.exp(log_pob).cpu().detach()

        return pob

    def sample(self, C=100):
        """Draw samples from the flow and return them as a NumPy array.

        Args:
            C: Either an ``int``, forwarded unchanged to ``nf.sample``
                (presumably the number of samples; confirm against
                ``NormalizingFlow.sample``), or array-like conditions, which
                are converted to a float32 tensor on ``self.DEVICE``.

        The flow is switched to eval mode first so that dropout inside the
        coupling layers is inactive while sampling, matching :meth:`pob`.
        """
        if not isinstance(C, int):
            C = torch.as_tensor(C, dtype=torch.float32, device=self.DEVICE)

        self.nf.eval()
        X = self.nf.sample(C).cpu().detach().numpy()
        return X
